from DEAD.AutoDecoder.Network import AutoDecoderNet
from DEAD.AutoDecoder.Config import *
# Define the neural network: its input is a spatial coordinate and its output
# is the eikonal quantity (a single value per point).
# NOTE(review): hidden_size, depth, data_num and latent_size are not defined in
# this view; presumably they come from the star import of
# DEAD.AutoDecoder.Config — confirm.
dead = AutoDecoderNet(
    coordinate_size=3,               # 3-D spatial coordinates in
    hidden_size=hidden_size,
    output_size=1,                   # one scalar out
    depth=depth,
    latent_vectors_num=data_num,     # number of latent vectors == data_num
    latent_size=latent_size,
).cuda()                             # requires a CUDA-capable device
dead.train()                         # switch the module into training mode

# Count the parameters that will be updated by training (requires_grad=True)
# and report them in millions.
total_params = sum(param.numel() for param in dead.parameters() if param.requires_grad)
print(f'Total parameters: {total_params/1e6}M')